package main

import (
	"log"
	"net"
	"os"
	"os/signal"
	"os/user"
	"strconv"
	"strings"
	"syscall"
	"time"

	"git.sr.ht/~sircmpwn/go-bare"
	"github.com/sevlyar/go-daemon"
)

// PASSWORD is the secret most recently handed to the agent by a client; it
// starts out empty and is kept in ordinary Go memory as a plain string.
// Accesses from concurrent goroutines must be synchronised by the caller.
// TODO: protect it with memguard.
var PASSWORD string

// getUser returns the account the current process runs as. A lookup failure
// is logged and returned with a nil user.
func getUser() (*user.User, error) {
	current, lookupErr := user.Current()
	if lookupErr == nil {
		return current, nil
	}
	log.Println("error getting user", lookupErr)
	return nil, lookupErr
}

// getUserByName looks up the account called username. A lookup failure is
// logged and returned with a nil user.
func getUserByName(username string) (*user.User, error) {
	found, lookupErr := user.Lookup(username)
	if lookupErr == nil {
		return found, nil
	}
	log.Println("error getting user", lookupErr)
	return nil, lookupErr
}

// socketName returns the Unix socket path used by the agent for username u.
// NOTE(review): the path is fixed and predictable inside world-writable /tmp,
// so another local user could create it first; callers should not trust
// whatever they find there without checking it.
func socketName(u string) string {
	const prefix = "/tmp/eeze-agent-"
	return prefix + u
}

// timeOut waits timeout seconds and then pushes SIGINT into sigc, so the
// shutdown path listening on that channel runs exactly as if the signal had
// been delivered by the OS.
func timeOut(timeout int, sigc chan os.Signal) {
	<-time.After(time.Second * time.Duration(timeout))
	sigc <- syscall.SIGINT
}

// catch blocks until a shutdown request arrives on sigc (an OS signal or the
// SIGINT injected by timeOut), removes the agent socket and exits the process
// with status 0. Because it calls os.Exit, deferred calls elsewhere in the
// program (including in main) do not run.
func catch(sigc chan os.Signal, socket string) {
	sig := <-sigc
	log.Println("received", sig)
	// A missing socket is fine (e.g. Listen never succeeded); anything else
	// means a stale socket may be left behind, so report it.
	if err := os.Remove(socket); err != nil && !os.IsNotExist(err) {
		log.Println("error removing socket", err)
	}
	log.Println("bye")
	os.Exit(0)
}

// passwordSem serialises access to PASSWORD. handle runs in its own goroutine
// per connection, so a store and a read can otherwise race, and a torn read of
// a string is undefined behaviour. It is a one-slot buffered channel used as a
// binary semaphore: send acquires, receive releases.
var passwordSem = make(chan struct{}, 1)

// handle serves a single request on conn and then closes the connection.
//
// The request starts with a u8 command read with the bare reader:
//   - 0: read a string and store it as PASSWORD;
//   - 1: write the current PASSWORD back as a string.
//
// Unknown commands are logged and ignored. Read/write errors are logged and
// returned.
func handle(conn net.Conn) error {
	// Without this every request leaked a file descriptor.
	defer conn.Close()

	r := bare.NewReader(conn)
	cmd, err := r.ReadU8()
	if err != nil {
		log.Println("error reading command", err)
		return err
	}
	switch cmd {
	case 0:
		// todo memguard
		password, err := r.ReadString()
		if err != nil {
			log.Println("error reading password to store", err)
			return err
		}
		passwordSem <- struct{}{}
		PASSWORD = password
		<-passwordSem
	case 1:
		passwordSem <- struct{}{}
		password := PASSWORD
		<-passwordSem
		// Write outside the critical section so a slow client cannot block
		// other connections.
		if err := bare.NewWriter(conn).WriteString(password); err != nil {
			log.Println("error giving password", err)
			return err
		}
	default:
		log.Println("ignoring unknown command", cmd)
	}
	return nil
}

// parseTimeout converts a command-line timeout, given in whole seconds, to an
// int. "0" is returned as 0. Unparsable, negative or unrepresentably large
// values are logged and replaced by the 300-second default.
func parseTimeout(timeout string) int {
	const fallback = 300
	// Largest number of seconds for which time.Duration(secs) * time.Second
	// does not overflow int64 nanoseconds; beyond it the product wraps and the
	// timer would fire immediately.
	const maxSeconds = int64(time.Duration(1<<63-1) / time.Second)

	// Atoi range-checks against the platform int size instead of silently
	// truncating an int64 on 32-bit platforms.
	secs, err := strconv.Atoi(timeout)
	switch {
	case err != nil:
		log.Println("error parsing timeout, defaulting to 300s", err)
	case secs < 0:
		log.Println("timeout cannot be < 0, defaulting to 300s")
	case int64(secs) > maxSeconds:
		log.Println("timeout too large, defaulting to 300s")
	default:
		return secs
	}
	return fallback
}

func main() {
	log.Println("main")
	timeout := 300
	skipArg := true
	user, err := getUser()
	if err != nil {
		log.Println("error getting user name", err)
		return
	}
	for i, arg := range os.Args {
		if skipArg {
			skipArg = false
			continue
		}
		if arg == "-u" {
			u, err := getUserByName(os.Args[i+1])
			if err != nil {
				log.Println("error getting user from name", err)
			} else {
				user = u
			}
			skipArg = true
		} else {
			timeout = parseTimeout(arg)
		}
	}
	log.Println("read args ", timeout)
	socket := socketName(user.Username)
	log.Println("socket name ", socket)

	uid, err := strconv.ParseInt(user.Uid, 10, 64)
	if err != nil {
		log.Println("error parsing uid", err)
		return
	}
	gid, err := strconv.ParseInt(user.Gid, 10, 64)
	if err != nil {
		log.Println("error parsing gid", err)
		return
	}
	context := new(daemon.Context)
	context.LogFileName = socket + ".ctx.log"
	log.Println("ctx ", context)
	child, err := context.Reborn()
	if err != nil {
		log.Println("error forking", err)
		return
	}
	log.Println("reborn")

	if child != nil {
		log.Println("waiting for socket")
		i := 10000
		for true {
			_, err := os.Stat(socket)
			if err == nil || !strings.Contains(err.Error(), "no such file or directory") {
				os.Chown(socket, int(uid), int(gid))
				os.Chmod(socket, 0600)
				os.Chmod(context.LogFileName, 0644)
				log.Println("chmoded socket")
				break
			}
			i--
		}
		log.Println("socket exists")
		return
	} else {
		defer context.Release()
		sigc := make(chan os.Signal, 1)
		log.Println("made channel")
		signal.Notify(sigc, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
		log.Println("setup signal")
		go catch(sigc, socket)
		log.Println("ran catch")
		if timeout > 0 {
			go timeOut(timeout, sigc)
			log.Println("ran timeout")
		}

		server, err := net.Listen("unix", socket)
		if err != nil {
			log.Println("error listening", err)
			return
		}
		log.Println("listening")

		log.Println("accepting")
		for {
			conn, err := server.Accept()
			if err != nil {
				log.Println("error accepting", err)
				return
			}
			go handle(conn)
		}
	}

}
